Integer Arithm Only Quant and Train
Paper
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
논문을 바탕으로 작성하였습니다.
논문을 바탕으로 작성하였습니다.Introduction
최근 SOTA를 달성하는 CNN 모델 구조는 복잡도와 연산 효율을 고려하지 않고 높은 정확도를 가지도록 발전해왔어요.
이는 스마트폰, 드론과 같은 자원이 한정된 모바일 플랫폼에 적합하지 않아요.
그래서 최소한의 정확도 하락으로 모델 사이즈와 추론 속도를 감소시키는 방법이 주목받고 있어요.
간단하게 생각하면 이를 위해 두 가지 접근법이 있네요.
- MobileNet, SqueezeNet, ShuffleNet, 그리고 DenseNet과 같은 메모리, 연산 효율적인 모델 구조를 디자인 하는 방법.
이 방법은 합리적인 측정 기준이 정해지지 않았어요.
AlexNet, VGG, 그리고 GoogleNet과 같은 모델은 이미 적은 성능 향상을 위해 과도한 파라미터를 사용하기 때문에, 모델 사이즈를 줄이는 것은 어렵지 않습니다.
이미 연산량과 정확도 사이 trade-off 관계를 가지는 MobileNet과 같은 모델을 quantization 하는 것이 의미있는 도전이겠죠?
- 모델의 weight, activation 등을 float32 연상에서 더 낮은 비트로 quantization 하는 방법.
많은 quantization 접근 방법은 실제 하드웨어에서 효율적인 향상을 보이지는 않더라구요.
weight만을 quantize하여 모델 사이즈를 줄이지만 연산 효율은 증가하지 않는다던가,
곱셈을 bit shift 연산으로 할 수 있게 하지만 하드웨어에서 별다른 이점이 되 지 않는다던가.
특히 곱셈은 큰 bit의 연산을 수행할 때만 비용이 많이 든다.
따라서 activation과 weight를 모두 quantization 하면 낮은 비트의 연산으로 곱셈을 피하지 않아도 된다.
이 논문에서 MobileNet을 이용하여 위에서 언급한 내용들을 다루는 quantization 방법을 다루겠습니다.
Quantization Scheme
Integer-arithmetic-only matrix multiplication
우리는 float32로 학습하고 integer만을 사용한 연산으로 추론하는 quantization 계획을 세울 것입니다.
실제값 을 int8로 quantize한다고 생각합시다.
여기서 는 quantized된 값, 와 는 각각 scale, zero point로 quantization paramter입니다.
우리는 각각의 tensor(activation, weight 등)에 대해서 각각의 quantization parameter를 사용할 거에요.
우리는 형태의 행렬인 의 곱 인 를 위와 같은 식을 이용해 표현한다고 생각 해봅시다.
는 의 행 열이라 정의하고 식을 다시 정리 해보죠.
그렇다면 아래처럼 표현할 수 있겠죠?
음 정리를 조금 해볼까요?
여기서 우리는 경험적으로 M이 (0, 1)에 분포하는 것을 알 수 있어요.
그렇다면 범위인 fixed-point 에 대해 을 표현해 볼게요.
사실 이 부분이 처음에는 이해가 안되었는데 LLM 모델 다뤄보니 보통 양자화 bit 범위보다 tensor의 범위가 더 작기 떄문에 scale 값이 1보다 작은 경우가 대부분이더라.그렇다면 범위인 fixed-point 에 대해 을 표현해 볼게요.
이제 우리는 bit shift 연산만으로 를 으로 만들 수 있게 되었어요.
여기서 결국 S3는 모르는데 어떻게 M을 계산하는 것인가?Efficient handling of zero-points
자 다시 돌아가서 4번 식을 봅시다.
4번 식을 약간 변형하면 아래와 같이 표현할 수 있어요.
여기서 와 는 각각 2번째 행렬의 k번째 열의 합, 1번째 행렬의 i번째 행의 합이네요.
우리는 크기의 두 행렬의 연산을 구하고 있는 중이잖아요?
그럼 각각 N번의 덧셈, N번의 덧셈이 사용되는 거네요.
전체 행렬에 대한 연산을 구한다고 하면 0, 1, ..., N 번째 열의 합이니 번 연산, 0, 1, ..., N번째 행의 합이니 번 연산, 둘이 합쳐 번의 연산 이 발생한다고 할 수 있죠.
그렇다면 이제 상수를 제외하고는 남은 친구가 하나 있네요.
두 행렬의 곱연산을 생각해 보면 하나의 결과 행렬 요소를 구하기 위해 번의 곱셈과 번의 덧셈, 즉 번의 연산이 발생합니다.
이 행위를 결과 행렬 요소의 수인 번 반복하여 총 번의 연산이 필요하죠.
그런데 왜 이렇게 연산을 풀어서 진행할까요?
4번 식처럼 zero point를 빼준 상태로 연산을 진행하게 되면 큰 값과 큰 값의 덧셈, 곱셈의 형태로 표현되어 연산 결과를 저장하기 위한 더 많은 비트가 필요할 수 있습니다.
그래서 우리는 전체 연산에서 곱셈이 차지하는 비중을 줄임으로써 비트 확장의 필요성을 감소시킵니다.
이는 완벽하게 비트 확장을 배제하는 것이 아닌, 연산 과정에서 비트 확장을 최소화 하려는 과정이라고 생각하면 이해가 더 쉬울 것 같네요.
~~
겱국에 더 적은 연산량을 가지게 된다는 것 같은데 맞는건가?~~아 zero point에 상관없이 작은 수의 형태로 연산을 하려고! 그래서 비트 더 많이 안쓰게 한다고!하지만 연산을 나누어 수행하는 만큼 연산 시간은 더 길어지는 것이 아닌가? 아니면 적은 비트로 연산하니 더 빠른가?Implementation of a typical fused layer
우리는 앞서 두 행렬간 곱연산에 대해 quantization 방법을 생각해 보았어요.
실제 모델에서는 두 행렬 activation과 weight의 곱연산 외에도 bias나 activation function을 가지고 있는데 이를 행렬곱과 직접 결합하는 형태로 만들어 볼 것이에요.
훈련중에 사용되는 연산을 추론 코드에서 재현하는 과정이라 할 수 있겠네요.
을 weight, 를 activation이라 가정하고, 둘 다 uint8 형태를 가진다고 할게요.
우리는 32-bit accumulator를 사용할 거에요.
그 이유는 조금 이따 설명하게요.
우리는 여기에 quantized된 bias를 더해줄 거에요.
bias는 int32 형태로 아래와 같은 quantization parameter를 가집니다.
bias는 output activation에 더해지기 때문에 bias-vector의 quantization error는 모델 전체를 편향되게 만드는 경향이 있어요.
그래서 bias의 quantization error를 줄이는게 중요하답니다.
quantization error를 최소화 하기 위해 bias는 32-bit로 quantize 해줄 거에요.
그렇기에 우리는 int32 accumulator를 사용하는 겁니다.
자 이제 우리는 할일이 두 가지밖에 안남았어요.
하나는 값을 8-bit에 맞게 scale down을 하는 것이고요,
다른 하나는 uint8로 cast down하고 activation function을 적용하여 8-bit output을 만드는 거에요.
down-scaling은 식 7번을 이용하여 [0, 255]범위에 꽉 들어 맞도록 값을 조절해야 합니다.
실제로 적용해봐야 이해가 될 듯ReLU, ReLU6와 같은 activation function은 거의 clamp를 하지 않기 때문에 layer와 합칠 필요가 없어요.
해야할 일은 down-scaling된 값을 최종적으로 uint8로 cast down 해주어 최종적으로 uint8의 결과를 내면됩 니다.
이 부분도 구현하며 알아보자!Training with simulated quantization
일반적으로 FP로 학습을 진행한 후 결과 weight에 대해 quantize를 진행하는 접근을 합니다.
(일부는 quantization 후 fine-tuning을 하기도 함!)
이러한 방법은 큰 모델에 대해서 상당한 정확성을 가지지만 작은 모델에서 정확도가 크게 하락해요.
그 문제점은 아래와 같아요.
- 채널간 weight의 범위가 너무 큰 차이를 가지는 경우
- outlier weight가 나머지 weight에 대한 정확도를 크게 하락시키는 경우
동일한 layer의 모든 채널에 대해 동일한 resolution으로 quantize를 하기에 위와 같은 상황에서 정확도 하락이 발생하는 거죠.
그래서 우리는 backpropagation은 기존과 동일하게 FP로 조절되지만 forward propagation에서 quantization의 영향을 시뮬레이션 하는 접근을 사용할 것입니다.
forward를 inference시와 동일하게 하지만 FP의 형태로 구현한다는 소리입니다!
이것이 fake quantization을 말하는 것인듯!weight는 input이 들어오기 전부터 이미 quantized된 상태에요.
만약 batch norm이 사용되는 레이어라면, weight를 quantize하기 전에 batch norm의 파라미터를 weight와 합칠 것입니다.
activation은 추론시에 quantized 됩니다. layer의 output에 activation function이 적용된 후 와 같이 말이죠!
은 실제 값, 는 quantization range, 는 반올림, 그리고 입니다. (8bit quantization)
Learning quantization ranges
quantization range는 weight와 activation에 따라 다르게 다룰 거예요.
weight에서 입니다.
이때 우리는 -128을 사용하지 않는 범위로만 사용할 거예요.
그 이유는 뒤에서 설명하겠습니다.
activation은 network에 따라 다르기 때문에, training 동안에 그 범위를 수집할 거예요.
그리고 exponential moving averages (EMA)를 사용하여 집계합니다.
물론 훈련 초반에는 범위가 급격하게 변하기 때문에, 안정적인 단계에 들어갈 때 까지 activation quantization을 사용하지 않을 거예요.
우리는 quantization을 진행하면 값을 가장 가까운 숫자에 맵핑하는 방식으로 진행하게 될 텐데,
0은 값이 없음 등과 같은 특별한 의미를 가지는 친구이기 때문에 정확하게 표현해줄 필요가 있어요.
그래서 수집한 [a; b]의 범위를 0.0이 어떠한 값에 근사되어 맵핑하는 것이 아닌, 정확하게 그 값을 가지도록 조정해줄 것입니다.